{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "<a target=\"_blank\" href=\"https://colab.research.google.com/github/cohere-ai/notebooks/blob/main/notebooks/llmu/Classify_Endpoint.ipynb\">\n",
        "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
        "</a>"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "xYo_6bTr21nz"
      },
      "source": [
        "# The Classify Endpoint\n",
        "\n",
        "In this lab, we'll learn how to use Cohere's Classify endpoint. This codelab accompanies the [Classify endpoint lesson](https://docs.cohere.com/docs/classify-endpoint/) of LLM University."
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "jlINcHlkFKXw"
      },
      "source": [
        "# Setting up"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "1ys2CHEgurfe"
      },
      "source": [
        "The first step is to install the Cohere Python SDK. Next, create an API key, which you can generate from the Cohere [dashboard](https://os.cohere.ai/register) or [CLI tool](https://docs.cohere.ai/cli-key), and use it in place of the `COHERE_API_KEY` placeholder when setting up the client."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "QdEURifRRUgy"
      },
      "outputs": [],
      "source": [
        "# Install the libraries (ClassifyExample requires version 5 or later of the Cohere SDK)\n",
        "! pip install \"cohere>=5\" altair umap-learn > /dev/null"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Import the libraries\n",
        "import cohere\n",
        "import pandas as pd\n",
        "import numpy as np\n",
        "import altair as alt\n",
        "import textwrap as tr\n",
        "\n",
        "# Set up the Cohere client\n",
        "co = cohere.Client(\"COHERE_API_KEY\") # Replace with your API key from https://dashboard.cohere.com/api-keys"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "7ha9EyO_RunK"
      },
      "source": [
        "# Classifying Text"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "yX8fQJ5LAcfz"
      },
      "source": [
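        "Cohere’s Classify endpoint makes it easy to take a list of texts and predict their categories, or classes. You pass in a handful of labeled examples along with the texts to classify, and the endpoint returns a predicted label and a confidence score for each text. A typical machine learning model needs many training examples to perform text classification, but with the Classify endpoint you can get started with as few as 5 examples per class."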
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "DtHJ02d7Rz8q"
      },
      "source": [
        "## Sentiment Analysis"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "-lSi_UmQEfy_"
      },
      "outputs": [],
      "source": [
        "# Create the training examples for the classifier\n",
        "\n",
        "from cohere import ClassifyExample\n",
        "\n",
        "examples = [ClassifyExample(text=\"I’m so proud of you\", label=\"positive\"), \n",
        "            ClassifyExample(text=\"What a great time to be alive\", label=\"positive\"), \n",
        "            ClassifyExample(text=\"That’s awesome work\", label=\"positive\"), \n",
        "            ClassifyExample(text=\"The service was amazing\", label=\"positive\"), \n",
        "            ClassifyExample(text=\"I love my family\", label=\"positive\"), \n",
        "            ClassifyExample(text=\"They don't care about me\", label=\"negative\"), \n",
        "            ClassifyExample(text=\"I hate this place\", label=\"negative\"), \n",
        "            ClassifyExample(text=\"The most ridiculous thing I've ever heard\", label=\"negative\"), \n",
        "            ClassifyExample(text=\"I am really frustrated\", label=\"negative\"), \n",
        "            ClassifyExample(text=\"This is so unfair\", label=\"negative\"),\n",
        "            ClassifyExample(text=\"This made me think\", label=\"neutral\"), \n",
        "            ClassifyExample(text=\"The good old days\", label=\"neutral\"), \n",
        "            ClassifyExample(text=\"What's the difference\", label=\"neutral\"), \n",
        "            ClassifyExample(text=\"You can't ignore this\", label=\"neutral\"), \n",
        "            ClassifyExample(text=\"That's how I see it\", label=\"neutral\")]"
      ]
    },
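    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because the Classify endpoint can work with as few as 5 examples per class, it can be worth checking that every label meets that minimum before calling the API. The cell below is a minimal sketch of such a check; `count_examples_per_label` is a hypothetical helper, not part of the Cohere SDK. It only assumes that each example exposes a `.label` attribute, as `ClassifyExample` does, so it is demonstrated with `SimpleNamespace` stand-ins. In this notebook you could call `count_examples_per_label(examples)` instead."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from collections import Counter\n",
        "from types import SimpleNamespace\n",
        "\n",
        "# Suggested minimum number of examples per class (from the text above)\n",
        "MIN_EXAMPLES_PER_CLASS = 5\n",
        "\n",
        "def count_examples_per_label(examples, minimum=MIN_EXAMPLES_PER_CLASS):\n",
        "  # Hypothetical helper: works with any objects that have a `.label` attribute\n",
        "  counts = Counter(example.label for example in examples)\n",
        "  # Labels with fewer examples than the suggested minimum\n",
        "  too_few = sorted(label for label, n in counts.items() if n < minimum)\n",
        "  return counts, too_few\n",
        "\n",
        "# Stand-ins shaped like ClassifyExample (text + label)\n",
        "demo = [SimpleNamespace(text=f'positive text {i}', label='positive') for i in range(5)]\n",
        "demo += [SimpleNamespace(text=f'negative text {i}', label='negative') for i in range(3)]\n",
        "\n",
        "counts, too_few = count_examples_per_label(demo)\n",
        "print(dict(counts))  # {'positive': 5, 'negative': 3}\n",
        "print(too_few)       # ['negative']"
      ]
    },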
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "QpC1Z4xEEWs0"
      },
      "outputs": [],
      "source": [
        "# Enter the inputs to be classified\n",
        "inputs=[\"Hello, world! What a beautiful day\",\n",
        "        \"It was a great time with great people\",\n",
        "        \"Great place to work\",\n",
        "        \"That was a wonderful evening\",\n",
        "        \"Maybe this is why\",\n",
        "        \"Let's start again\",\n",
        "        \"That's how I see it\",\n",
        "        \"These are all facts\",\n",
        "        \"This is the worst thing\",\n",
        "        \"I cannot stand this any longer\",\n",
        "        \"This is really annoying\",\n",
        "        \"I am just plain fed up\"]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "etTr200IRszm"
      },
      "outputs": [],
      "source": [
        "# A function that classifies a list of inputs given the examples\n",
        "def classify_text(inputs, examples):\n",
        "  \"\"\"\n",
        "  Classify a list of input texts\n",
        "  Arguments:\n",
        "    inputs(list[str]): a list of input texts to be classified\n",
        "    examples(list[ClassifyExample]): a list of example texts and class labels\n",
        "  Returns:\n",
        "    classifications(list): one result per input, with its predicted label and confidence values\n",
        "  \"\"\"\n",
        "  # Classify text by calling the Classify endpoint\n",
        "  response = co.classify(\n",
        "    model='embed-english-v3.0',\n",
        "    inputs=inputs,\n",
        "    examples=examples)\n",
        "  \n",
        "  classifications = response.classifications\n",
        "  \n",
        "  return classifications"
      ]
    },
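    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The display loop later in this notebook reads `pred.prediction` and `pred.confidence` from each classification and pairs it with its input using `zip`, which assumes the results come back in the same order as `inputs`. The cell below is a minimal sketch that makes the same assumptions and collects these values into plain records. `to_records` is a hypothetical helper, and the `SimpleNamespace` object stands in for a real result (its confidence is taken from the output printed further down). With real results, `pd.DataFrame(to_records(inputs, predictions))` would turn the records into a table."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from types import SimpleNamespace\n",
        "\n",
        "def to_records(inputs, classifications):\n",
        "  # Hypothetical helper: pairs each input with its predicted label and confidence\n",
        "  # (zip stops at the shorter list, so both lists should have the same length)\n",
        "  return [\n",
        "    {'input': text, 'prediction': c.prediction, 'confidence': c.confidence}\n",
        "    for text, c in zip(inputs, classifications)\n",
        "  ]\n",
        "\n",
        "# Stand-in shaped like one classification result\n",
        "demo_results = [SimpleNamespace(prediction='positive', confidence=0.49)]\n",
        "records = to_records(['It was a great time with great people'], demo_results)\n",
        "print(records)"
      ]
    },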
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "kOdL3U0jRswU",
        "outputId": "93ea6111-ef75-4593-971c-b20b5dfb3d22"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Input: Hello, world! What a beautiful day\n",
            "Prediction: positive\n",
            "Confidence: 0.40\n",
            "----------\n",
            "Input: It was a great time with great people\n",
            "Prediction: positive\n",
            "Confidence: 0.49\n",
            "----------\n",
            "Input: Great place to work\n",
            "Prediction: positive\n",
            "Confidence: 0.50\n",
            "----------\n",
            "Input: That was a wonderful evening\n",
            "Prediction: positive\n",
            "Confidence: 0.48\n",
            "----------\n",
            "Input: Maybe this is why\n",
            "Prediction: neutral\n",
            "Confidence: 0.45\n",
            "----------\n",
            "Input: Let's start again\n",
            "Prediction: neutral\n",
            "Confidence: 0.42\n",
            "----------\n",
            "Input: That's how I see it\n",
            "Prediction: neutral\n",
            "Confidence: 0.53\n",
            "----------\n",
            "Input: These are all facts\n",
            "Prediction: neutral\n",
            "Confidence: 0.41\n",
            "----------\n",
            "Input: This is the worst thing\n",
            "Prediction: negative\n",
            "Confidence: 0.52\n",
            "----------\n",
            "Input: I cannot stand this any longer\n",
            "Prediction: negative\n",
            "Confidence: 0.52\n",
            "----------\n",
            "Input: This is really annoying\n",
            "Prediction: negative\n",
            "Confidence: 0.56\n",
            "----------\n",
            "Input: I am just plain fed up\n",
            "Prediction: negative\n",
            "Confidence: 0.57\n",
            "----------\n"
          ]
        }
      ],
      "source": [
        "# Classify the inputs\n",
        "predictions = classify_text(inputs, examples)\n",
        "\n",
        "# Display the classification outcomes\n",
        "for inp, pred in zip(inputs, predictions):\n",
        "  class_pred = pred.prediction\n",
        "  class_conf = pred.confidence\n",
        "\n",
        "  print(f\"Input: {inp}\")\n",
        "  print(f\"Prediction: {class_pred}\")\n",
        "  print(f\"Confidence: {class_conf:.2f}\")\n",
        "  print(\"-\"*10)"
      ]
    },
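    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The confidence scores above range from 0.40 to 0.57, so some predictions are less certain than others. One way to act on this is to flag predictions below a threshold for manual review. The cell below is a minimal sketch of that idea, using a threshold of 0.45 that is an arbitrary value chosen for illustration, not a recommendation. `flag_low_confidence` is a hypothetical helper, and the `SimpleNamespace` stand-ins reuse three rounded results from the output above. In this notebook you could call `flag_low_confidence(inputs, predictions)` directly."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from types import SimpleNamespace\n",
        "\n",
        "# Hypothetical cut-off for illustration; tune it on your own data\n",
        "LOW_CONFIDENCE_THRESHOLD = 0.45\n",
        "\n",
        "def flag_low_confidence(inputs, classifications, threshold=LOW_CONFIDENCE_THRESHOLD):\n",
        "  # Return (input, prediction, confidence) for results below the threshold\n",
        "  return [\n",
        "    (text, c.prediction, c.confidence)\n",
        "    for text, c in zip(inputs, classifications)\n",
        "    if c.confidence < threshold\n",
        "  ]\n",
        "\n",
        "# Stand-ins using three (rounded) results printed above\n",
        "demo_inputs = ['Hello, world! What a beautiful day', 'These are all facts', 'I am just plain fed up']\n",
        "demo_preds = [\n",
        "  SimpleNamespace(prediction='positive', confidence=0.40),\n",
        "  SimpleNamespace(prediction='neutral', confidence=0.41),\n",
        "  SimpleNamespace(prediction='negative', confidence=0.57),\n",
        "]\n",
        "\n",
        "flagged = flag_low_confidence(demo_inputs, demo_preds)\n",
        "for text, label, conf in flagged:\n",
        "  print(f'Review: {text!r} -> {label} ({conf:.2f})')"
      ]
    },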
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "5XTtVliyNVyO"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3.10.0 64-bit ('3.10.0')",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.4"
    },
    "vscode": {
      "interpreter": {
        "hash": "1fb8019e3560b882083e525615cf48e713d3a7345a15eb723d805e91aa410aac"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
